// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <random>

#include "include/xnnpack.h"

// align a size up to XNN_EXTRA_BYTES
#define XNN_PAD_EXTRA_BYTES(s, t) \
  (((s) + XNN_EXTRA_BYTES / sizeof(t) - 1) & ~(XNN_EXTRA_BYTES / sizeof(t) - 1))

namespace models {

xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) {
  xnn_status status;
  xnn_subgraph_t subgraph = nullptr;
  status = xnn_create_subgraph(/*num_external_values=*/4, 0, &subgraph);
  if (status != xnn_status_success) {
    std::cerr << "failed to create subgrpah" << std::endl;
    return nullptr;
  }

  std::random_device random_device;
  auto rng = std::mt19937(random_device());

  uint32_t v0 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v0_dims = {{b, s, n, t}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v0_dims.size(), v0_dims.data(),
      /*data=*/nullptr, 0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &v0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v0" << std::endl;
    return nullptr;
  }

  uint32_t v1 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v1_dims = {{b, t, n, h}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v1_dims.size(), v1_dims.data(),
      /*data=*/nullptr, 1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &v1);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v1" << std::endl;
    return nullptr;
  }

  uint32_t v2 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v2_dims = {{b, s, n, h}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v2_dims.size(), v2_dims.data(),
      /*data=*/nullptr, 2, XNN_VALUE_FLAG_EXTERNAL_INPUT, &v2);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v2" << std::endl;
    return nullptr;
  }

  uint32_t v3 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v3_dims = {{b, t, n, h}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v3_dims.size(), v3_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v3);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v3" << std::endl;
    return nullptr;
  }

  uint32_t v4 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v4_dims = {{b, n, t, h}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v4_dims.size(), v4_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v4);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v4" << std::endl;
    return nullptr;
  }

  uint32_t v5 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v5_dims = {{b, n, s, h}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v5_dims.size(), v5_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v5);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v5" << std::endl;
    return nullptr;
  }

  uint32_t v6 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v6_dims = {{b, n, t, s}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v6_dims.size(), v6_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v6);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v6" << std::endl;
    return nullptr;
  }

  uint32_t v7 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v7_dims = {{b, n, t, s}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v7_dims.size(), v7_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v7);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v7" << std::endl;
    return nullptr;
  }

  uint32_t v8 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v8_dims = {{b, n, t, s}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v8_dims.size(), v8_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v8);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v8" << std::endl;
    return nullptr;
  }

  uint32_t v9 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v9_dims = {{b, n, t, s}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v9_dims.size(), v9_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v9);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v9" << std::endl;
    return nullptr;
  }

  uint32_t v10 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v10_dims = {{b, n, t, s}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v10_dims.size(), v10_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v10);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v10" << std::endl;
    return nullptr;
  }

  uint32_t v11 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v11_dims = {{b, n, t, t}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v11_dims.size(), v11_dims.data(),
      /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v11);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v11" << std::endl;
    return nullptr;
  }

  uint32_t v12 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 4> v12_dims = {{b, t, n, t}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, v12_dims.size(), v12_dims.data(),
      /*data=*/nullptr, 3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &v12);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor v12" << std::endl;
    return nullptr;
  }

  alignas(16) static std::array<float, XNN_PAD_EXTRA_BYTES(1, float)> w13_data;
  uint32_t w13 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 1> w13_dims = {{1}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, w13_dims.size(), w13_dims.data(),
      /*data=*/w13_data.data(), XNN_INVALID_VALUE_ID, /*flags=*/0, &w13);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor w13" << std::endl;
    return nullptr;
  }

  static const std::array<int32_t, 4> w14_data = {
      0x000000,
      0x000002,
      0x000001,
      0x000003,
  };
  uint32_t w14 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 1> w14_dims = {{4}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_int32, w14_dims.size(), w14_dims.data(),
      /*data=*/w14_data.data(), XNN_INVALID_VALUE_ID, /*flags=*/0, &w14);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor w14" << std::endl;
    return nullptr;
  }

  alignas(16) static std::array<float, XNN_PAD_EXTRA_BYTES(1, float)> w15_data;
  uint32_t w15 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 1> w15_dims = {{1}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, w15_dims.size(), w15_dims.data(),
      /*data=*/w15_data.data(), XNN_INVALID_VALUE_ID, /*flags=*/0, &w15);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor w15" << std::endl;
    return nullptr;
  }

  static const std::array<int32_t, 4> w16_data = {
      0x000000,
      0x000002,
      0x000003,
      0x000001,
  };
  uint32_t w16 = XNN_INVALID_VALUE_ID;
  std::array<size_t, 1> w16_dims = {{4}};
  status = xnn_define_tensor_value(
      subgraph, xnn_datatype_int32, w16_dims.size(), w16_dims.data(),
      /*data=*/w16_data.data(), XNN_INVALID_VALUE_ID, /*flags=*/0, &w16);
  if (status != xnn_status_success) {
    std::cerr << "failed to create tensor w16" << std::endl;
    return nullptr;
  }

  auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, +1.0f),
                          std::ref(rng));
  std::generate(w13_data.begin(), w13_data.end(), std::ref(f32rng));
  std::generate(w15_data.begin(), w15_data.end(), std::ref(f32rng));

  xnn_binary_params v3_params = {-std::numeric_limits<float>::infinity(),
                                 std::numeric_limits<float>::infinity()};
  status =
      xnn_define_binary(subgraph, xnn_binary_multiply, &v3_params, v1, w13, v3,
                        /*flags=*/0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #0" << std::endl;
    return nullptr;
  }

  std::array<size_t, 4> perm_v3_w14_v4 = {
      (size_t)w14_data[0], (size_t)w14_data[1], (size_t)w14_data[2],
      (size_t)w14_data[3]};
  status =
      xnn_define_static_transpose(subgraph,
                                  /*num_dims=*/perm_v3_w14_v4.size(),
                                  /*perm=*/perm_v3_w14_v4.data(), v3, v4, 0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #1" << std::endl;
    return nullptr;
  }

  std::array<size_t, 4> perm_v2_w14_v5 = {
      (size_t)w14_data[0], (size_t)w14_data[1], (size_t)w14_data[2],
      (size_t)w14_data[3]};
  status =
      xnn_define_static_transpose(subgraph,
                                  /*num_dims=*/perm_v2_w14_v5.size(),
                                  /*perm=*/perm_v2_w14_v5.data(), v2, v5, 0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #2" << std::endl;
    return nullptr;
  }

  status = xnn_define_batch_matrix_multiply(subgraph, v4, v5, v6,
                                            XNN_FLAG_TRANSPOSE_B);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #3" << std::endl;
    return nullptr;
  }

  xnn_binary_params v7_params = {-std::numeric_limits<float>::infinity(),
                                 std::numeric_limits<float>::infinity()};
  status =
      xnn_define_binary(subgraph, xnn_binary_multiply, &v7_params, v6, w15, v7,
                        /*flags=*/0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #4" << std::endl;
    return nullptr;
  }

  status = xnn_define_unary(subgraph, xnn_unary_tanh,
                            /*params=*/nullptr, v7, v8, 0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #5" << std::endl;
    return nullptr;
  }

  status = xnn_define_softmax(subgraph, v8, v9,
                              /*flags=*/0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #6" << std::endl;
    return nullptr;
  }

  std::array<size_t, 4> perm_v0_w16_v10 = {
      (size_t)w16_data[0], (size_t)w16_data[1], (size_t)w16_data[2],
      (size_t)w16_data[3]};
  status =
      xnn_define_static_transpose(subgraph,
                                  /*num_dims=*/perm_v0_w16_v10.size(),
                                  /*perm=*/perm_v0_w16_v10.data(), v0, v10, 0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #7" << std::endl;
    return nullptr;
  }

  status = xnn_define_batch_matrix_multiply(subgraph, v9, v10, v11,
                                            XNN_FLAG_TRANSPOSE_B);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #8" << std::endl;
    return nullptr;
  }

  std::array<size_t, 4> perm_v11_w14_v12 = {
      (size_t)w14_data[0], (size_t)w14_data[1], (size_t)w14_data[2],
      (size_t)w14_data[3]};
  status = xnn_define_static_transpose(subgraph,
                                       /*num_dims=*/perm_v11_w14_v12.size(),
                                       /*perm=*/perm_v11_w14_v12.data(), v11,
                                       v12, 0);
  if (status != xnn_status_success) {
    std::cerr << "failed to create node #9" << std::endl;
    return nullptr;
  }

  return subgraph;
}

}  // namespace models
